package database

import (
	"du/config"
	"fmt"
	"math/rand"
	"sync"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

// Database groups the GORM handles for a MySQL master/slave topology:
// dbMaster is the primary connection and dbSlaves holds one connection per
// configured slave (empty when no slaves are configured).
type Database struct {
	dbMaster *gorm.DB   // primary connection
	dbSlaves []*gorm.DB // slave connections; may be empty
}

// Package-level singleton state backing GetDB.
var (
	once       sync.Once // guards the one-time allocation of dbInstance
	dbInstance *Database // shared Database; its connections are set by InitDB
)

// GetDB returns the shared Database instance, lazily allocating an empty one
// on the first call. The instance carries no connections until InitDB fills
// them in. The sync.Once guard makes concurrent first calls allocate exactly once.
func GetDB() *Database {
	once.Do(func() { dbInstance = new(Database) })
	return dbInstance
}

// InitDB opens the master connection and one connection per configured slave
// using the settings in config.App.Mysql. It stores them on the package-level
// Database singleton returned by GetDB, and returns that singleton.
//
// If any connection fails, every connection opened so far is closed before
// the error is returned, so a failed InitDB does not leak connection pools.
// The returned error wraps the underlying driver error (%w), so callers can
// inspect it with errors.Is / errors.As.
//
// NOTE(review): InitDB mutates the shared singleton without locking; it is
// presumably meant to be called once during startup — confirm.
func InitDB() (*Database, error) {
	// DSN template shared by the master and every slave connection.
	const dsnFormat = "%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local"

	masterDsn := fmt.Sprintf(dsnFormat,
		config.App.Mysql.Master.Username,
		config.App.Mysql.Master.Password,
		config.App.Mysql.Master.Host,
		config.App.Mysql.Master.Port,
		config.App.Mysql.Master.Database)

	dbMaster, err := gorm.Open(mysql.Open(masterDsn), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to master database: %w", err)
	}

	// Open one connection per configured slave.
	dbSlaves := make([]*gorm.DB, 0, len(config.App.Mysql.Slave))
	for i, slaveCfg := range config.App.Mysql.Slave {
		slaveDsn := fmt.Sprintf(dsnFormat,
			slaveCfg.Username,
			slaveCfg.Password,
			slaveCfg.Host,
			slaveCfg.Port,
			slaveCfg.Database)
		dbSlave, err := gorm.Open(mysql.Open(slaveDsn), &gorm.Config{})
		if err != nil {
			// Release the master and every slave opened so far; otherwise
			// their pools would stay open with no reference to close them.
			closeGormDBs(dbMaster)
			closeGormDBs(dbSlaves...)
			return nil, fmt.Errorf("failed to connect to slave database %d: %w", i, err)
		}
		dbSlaves = append(dbSlaves, dbSlave)
	}

	// Attach the connections to the shared singleton.
	inst := GetDB()
	inst.dbMaster = dbMaster
	inst.dbSlaves = dbSlaves

	// Return the whole instance (master and slaves), not just the master.
	return inst, nil
}

// closeGormDBs closes, best-effort, the connection pools behind dbs and skips
// nil entries. Errors from DB() and Close() are ignored deliberately: this runs
// only on an initialization error path that already reports a more relevant
// error to the caller.
func closeGormDBs(dbs ...*gorm.DB) {
	for _, db := range dbs {
		if db == nil {
			continue
		}
		if sqlDB, err := db.DB(); err == nil {
			_ = sqlDB.Close()
		}
	}
}

// GetMasterDB returns the master (primary) connection held by d.
//
// The method reads from its receiver rather than a package global, so it works
// for any *Database. A nil receiver falls back to the shared instance from
// GetDB, which keeps the previous global-based behavior for such callers.
func (d *Database) GetMasterDB() *gorm.DB {
	if d == nil {
		d = GetDB()
	}
	return d.dbMaster
}

// GetSlaveDB returns a connection chosen pseudo-randomly (math/rand.Intn) from
// the slaves held by d. When d has no slaves, it falls back to the master
// connection, so callers can always route reads through this method.
//
// The method reads from its receiver rather than a package global. A nil
// receiver falls back to the shared instance from GetDB.
func (d *Database) GetSlaveDB() *gorm.DB {
	if d == nil {
		d = GetDB()
	}
	if len(d.dbSlaves) == 0 {
		return d.dbMaster
	}
	// math/rand's top-level functions are safe for concurrent use.
	return d.dbSlaves[rand.Intn(len(d.dbSlaves))]
}
